#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
import bigdl.nn.layer as BLayer
import bigdl.util.common as BCommon
from bigdl.util.common import get_activation_by_name
from keras.models import model_from_json
from keras.models import Sequential, Model, Layer
import keras
import warnings
from bigdl.keras.ToBigDLHelper import *


def unsupport_exp(name):
    """
    Raise a generic Exception reporting that `name` is not supported.
    This function never returns normally.
    :param name: name of the unsupported feature, interpolated into the message
    """
    message = "We don't support %s for now" % name
    raise Exception(message)


class WeightLoader:
    """
    Copy the weights of a Keras model into the BigDL model generated from it.
    Layers are matched by name.
    """

    # TODO: add more unit tests
    @staticmethod
    def load_weights_from_kmodel(bmodel, kmodel):
        """
        Convert the weights of every weighted Keras layer and set them on the
        BigDL layer that has the same name.
        bmodel and kmodel should have the same layers, and this method should
        only be called when bmodel is generated by kmodel.
        :param bmodel: BigDL model whose layers receive the weights
        :param kmodel: keras model providing the weights
        Raises Exception if a weighted Keras layer has no BigDL layer with the same name.
        """
        keras_name_to_layer = WeightLoader.__keras_name_to_Layers(kmodel, with_weights=True)
        bigdl_name_to_layer = WeightLoader.__bigdl_name_to_Layers(bmodel, with_weights=True)
        # klayer should be just a layer, not seq, not Model
        for klayer in keras_name_to_layer.values():
            if klayer.name not in bigdl_name_to_layer:
                # The layer is wrapped in a tuple so that `%` formats it as one value.
                raise Exception("should not enter here, klayer: %s" % (klayer,))
            blayer = bigdl_name_to_layer[klayer.name]
            bigdl_weights = WeightsConverter.get_bigdl_weights_from_klayer(klayer)
            blayer.set_weights(bigdl_weights)
            if isinstance(klayer, keras.layers.BatchNormalization):
                # The running statistics are copied separately from the weights.
                blayer.set_running_mean(keras.backend.eval(klayer.running_mean))
                blayer.set_running_std(keras.backend.eval(klayer.running_std))

    @staticmethod
    def load_weights_from_json_hdf5(def_json, weights_hdf5, by_name=False):
        """
        Build a BigDL model from a Keras json definition and load the weights
        stored in an HDF5 file into it.
        The file path can be stored in a local file system, HDFS, S3,
        or any Hadoop-supported file system.
        :param def_json: path of the Keras json definition
        :param weights_hdf5: path of the HDF5 weights file
        :param by_name: forwarded to load_weights_from_hdf5
        :return: the BigDL model with the weights loaded
        """
        # Read and parse the definition once; the same Keras model is used both
        # to generate the BigDL model and to load the HDF5 weights.
        def_value = BCommon.text_from_path(def_json)
        kmodel = model_from_json(def_value)
        bmodel = DefinitionLoader.from_kmodel(kmodel)
        WeightLoader.load_weights_from_hdf5(bmodel, kmodel, weights_hdf5, by_name)
        return bmodel

    @staticmethod
    def load_weights_from_hdf5(bmodel, kmodel, filepath, by_name=False):
        '''Loads all layer weights from a HDF5 save file.
        filepath can be stored in a local file system, HDFS, S3,
        or any Hadoop-supported file system.
        If `by_name` is False (default) weights are loaded
        based on the network's execution order topology,
        meaning the layers in the execution sequence should be exactly
        the same as the saved architecture.

        If `by_name` is True, weights are loaded into layers
        only if they share the same name. This is useful
        for fine-tuning or transfer-learning models where
        some of the layers have changed.

        The weights are first loaded into `kmodel` and then copied to `bmodel`.
        '''
        local_file_path = BCommon.get_local_file(filepath)
        kmodel.load_weights(filepath=local_file_path, by_name=by_name)
        WeightLoader.load_weights_from_kmodel(bmodel, kmodel)

    @staticmethod
    def __keras_name_to_Layers(model, with_weights=False):
        """
        Map layer name -> Keras layer for every non-container layer collected
        by DefinitionLoader for `model`.
        :param with_weights: if True keep only layers whose get_weights() is non-empty
        """
        total_layers = DefinitionLoader(model).node_id_to_layer.values()
        # Containers are excluded before get_weights() is evaluated so that the
        # weights of a whole container are never fetched just to be discarded.
        layers = [l for l in total_layers
                  if not isinstance(l, (Model, Sequential))
                  and (not with_weights or l.get_weights())]
        return {layer.name: layer for layer in layers}

    @staticmethod
    def __bigdl_name_to_Layers(model, with_weights=False):
        """
        Map layer name -> BigDL layer for the flattened layers of `model`
        (containers included).
        :param with_weights: if True keep only layers whose is_with_weights() is true
        """
        # NB: Container in BigDL is_with_weights() is true if one of the nested layer with_weights
        # but in Keras container get_weights() return false even if the nested layer with_weights
        all_layers = model.flattened_layers(include_container=True)
        if with_weights:
            layers = [l for l in all_layers if l.is_with_weights()]
        else:
            layers = all_layers

        return {layer.name(): layer for layer in layers}


class WeightsConverter:
    """
    Convert keras weights to bigdl weights
    The shape of weights would be changed if using different backend,
    so we only test against TensorFlow backend.
    TODO: Support th backend as well.

    Every `convert_<lowercased keras class name>` static method takes
    (klayer, weights), where weights is the list returned by the layer's
    get_weights(), and returns the list of ndarray in BigDL layout.
    """

    @staticmethod
    def get_converter(class_name):
        """
        Look up the converter for a Keras layer class.
        :param class_name: keras layer class name, e.g. "Dense"
        :return: the matching `convert_*` static method
        Raises Exception (through unsupport_exp) if there is no such converter.
        """
        function_name = "convert_" + class_name.lower()
        if not hasattr(WeightsConverter, function_name):
            # unsupport_exp raises by itself; the outer `raise` never receives a value.
            raise unsupport_exp(class_name)
        return getattr(WeightsConverter, function_name)

    @staticmethod
    def to_bigdl_weights(klayer, weights):
        """
        Convert keras weights per layer to bigdl format.
        :param klayer: keras layer, its class name selects the converter
        :param weights: a list of ndarray or a ndarray
        """
        converter = WeightsConverter.get_converter(klayer.__class__.__name__)
        return converter(klayer, weights)

    @staticmethod
    def get_bigdl_weights_from_klayer(klayer):
        """Convert the current weights of `klayer` to bigdl format."""
        # we should use get_weights instead of klayer.weights
        return WeightsConverter.to_bigdl_weights(klayer, klayer.get_weights())

    @staticmethod
    def get_weights_from_kmodel(kmodel):
        """
        Convert kmodel's weights to bigdl format.
        We are supposing the order is the same as the execution order.
        :param kmodel: keras model
        :return: list of ndarray
        """
        layers_with_weights = [layer for layer in kmodel.layers if layer.weights]
        bweights = []
        for klayer in layers_with_weights:
            # each converted result would be [weights, bias] or [weights]
            bweights.extend(WeightsConverter.get_bigdl_weights_from_klayer(klayer))
        return bweights

    @staticmethod
    def convert_timedistributed(klayer, weights):
        """Convert with the converter of the wrapped layer."""
        return WeightsConverter.to_bigdl_weights(klayer.layer, weights)

    @staticmethod
    def convert_bidirectional(klayer, weights):
        """
        Split the weights in two halves (forward, then backward), convert each
        half with the converter of the wrapped layer and concatenate the lists.
        """
        half = len(weights) // 2
        bweights_forward = WeightsConverter.to_bigdl_weights(klayer.layer, weights[:half])
        bweights_backward = WeightsConverter.to_bigdl_weights(klayer.layer, weights[half:])
        return bweights_forward + bweights_backward

    @staticmethod
    def convert_dense(klayer, weights):
        """Transpose the kernel; the bias, when present, is passed through."""
        if len(weights) == 1:  # if without bias
            return [np.transpose(weights[0])]
        return [np.transpose(weights[0]), weights[1]]

    @staticmethod
    def convert_timedistributeddense(klayer, weights):
        """Transpose the kernel; the bias, when present, is passed through."""
        if len(weights) == 1:  # if without bias
            return [np.transpose(weights[0])]
        return [np.transpose(weights[0]), weights[1]]

    @staticmethod
    def convert_batchnormalization(klayer, weights):
        """Keep only the first two arrays (gamma, beta)."""
        gamma = weights[0]
        beta = weights[1]
        return [gamma, beta]

    @staticmethod
    def convert_atrousconvolution2d(klayer, weights):
        return weights

    @staticmethod
    def convert_atrousconvolution1d(klayer, weights):
        return [np.transpose(weights[0], (3, 2, 0, 1)), weights[1]]

    @staticmethod
    def convert_deconvolution2d(klayer, weights):
        """Swap the first two kernel axes and add a leading axis of size 1."""
        w = np.transpose(weights[0], (1, 0, 2, 3))
        weight = np.expand_dims(w, 0)
        if len(weights) > 1:
            return [weight, weights[1]]
        return [weight]

    @staticmethod
    def convert_convolution2d(klayer, weights):
        weight = np.expand_dims(weights[0], 0)  # bigdl has a leading dim with value 1
        if len(weights) > 1:
            return [weight, weights[1]]
        return [weight]

    @staticmethod
    def convert_convolution1d(klayer, weights):
        return WeightsConverter.convert_convolution2d(klayer, weights)

    @staticmethod
    def convert_convolution3d(klayer, weights):
        return weights

    @staticmethod
    def convert_embedding(klayer, weights):
        return weights

    @staticmethod
    def convert_simplernn(klayer, weights):
        return [np.transpose(weights[0]), np.transpose(weights[1]), weights[2]]

    @staticmethod
    def convert_lstm(klayer, weights):
        """
        Regroup the 12 keras arrays (read in groups of three) into three bigdl
        arrays: concatenated transposes of indices 0/3/6/9, concatenated arrays
        2/5/8/11 and concatenated transposes of indices 1/4/7/10.
        """
        w1 = np.concatenate((weights[0].T, weights[3].T, weights[6].T, weights[9].T))
        w2 = np.concatenate((weights[2], weights[5], weights[8], weights[11]))
        w3 = np.concatenate((weights[1].T, weights[4].T, weights[7].T, weights[10].T))
        return [w1, w2, w3]

    @staticmethod
    def convert_convlstm2d(klayer, weights):
        """
        Reorder the 12 keras arrays group-wise (groups starting at 6, 0, 3, 9);
        inside each group the order becomes [first, third, second] and the
        first and second arrays get a leading axis of size 1.
        """
        return [np.expand_dims(weights[6], 0), weights[8], np.expand_dims(weights[7], 0),
                np.expand_dims(weights[0], 0), weights[2], np.expand_dims(weights[1], 0),
                np.expand_dims(weights[3], 0), weights[5], np.expand_dims(weights[4], 0),
                np.expand_dims(weights[9], 0), weights[11], np.expand_dims(weights[10], 0)]

    @staticmethod
    def convert_gru(klayer, weights):
        w1 = np.concatenate((weights[3].T, weights[0].T, weights[6].T))
        w2 = np.concatenate((weights[5], weights[2], weights[8]))
        w3 = np.concatenate((weights[4].T, weights[1].T))
        w4 = weights[7].T
        return [w1, w2, w3, w4]

    @staticmethod
    def convert_highway(klayer, weights):
        if len(weights) == 2:  # if without bias
            return [weights[1].T, weights[0].T]
        return [weights[1].T, weights[3], weights[0].T, weights[2]]

    @staticmethod
    def convert_maxoutdense(klayer, weights):
        """
        Stack the transposed slices of the 3D keras kernel along the first axis;
        the bias, when present, is flattened to shape[0] * shape[2] elements.
        """
        k_weights = weights[0]
        b_weights = np.concatenate([k_weights[i].T for i in range(k_weights.shape[0])])
        if len(weights) == 1:  # if without bias
            return [b_weights]
        return [b_weights, weights[1].reshape(k_weights.shape[0]*k_weights.shape[2], )]

    @staticmethod
    def convert_srelu(klayer, weights):
        return weights

    @staticmethod
    def convert_separableconvolution2d(klayer, weights):
        """Append a zero bias when the keras layer has none."""
        if len(weights) == 2:  # if without bias
            # The bias length is read from the second kernel; the axis depends on dim_ordering.
            if klayer.dim_ordering == "th":
                bias_size = weights[1].shape[0]
            else:
                bias_size = weights[1].shape[3]
            return [weights[0], weights[1], np.zeros(bias_size, )]
        return weights

    @staticmethod
    def convert_locallyconnected1d(klayer, weights):
        bweights1 = np.transpose(weights[0], (0, 2, 1))
        if len(weights) == 1:  # if without bias
            return [bweights1]
        return [bweights1, weights[1]]

    @staticmethod
    def convert_locallyconnected2d(klayer, weights):
        """Swap the last two kernel axes; merge the first two bias axes into one."""
        bweights1 = np.transpose(weights[0], (0, 2, 1))
        if len(weights) == 1:  # if without bias
            return [bweights1]
        bias = weights[1]
        bweights2 = bias.reshape(bias.shape[0]*bias.shape[1], bias.shape[2])
        return [bweights1, bweights2]


class DefinitionLoader:
    """
    Build a BigDL model from a Keras model (Sequential, Model or a single Layer).
    Use the `from_*` class methods as entry points.
    """

    @staticmethod
    def __build_node_id_2_klayer(kmodel, node_id_to_config_layer):
        """
        The result would contain all of the layers including nested layers.
        :param kmodel: a keras model which can be Sequential or Model
        :param node_id_to_config_layer: a container to store the result
               (layer name -> keras layer); it is filled in place
        """
        node_id_to_config_layer[kmodel.name] = kmodel  # include itself as well

        def gather_result(layers):
            if layers:  # layers maybe None here.
                for layer in layers:
                    # Names already collected are skipped, so a layer is recursed into only once.
                    if layer.name not in node_id_to_config_layer:
                        node_id_to_config_layer[layer.name] = layer
                        DefinitionLoader.__build_node_id_2_klayer(layer, node_id_to_config_layer)
        if hasattr(kmodel, "layers"):
            gather_result(kmodel.layers)
        if hasattr(kmodel, "flattened_layers"):
            gather_result(kmodel.flattened_layers)  # it's an expensive operation

    @staticmethod
    def __build_node_id_2_kclayer(kmodel, node_id_to_config_layer):
        """
        Fill `node_id_to_config_layer` (layer name -> layer config) in place
        from the config of `kmodel`.
        Raises Exception if kmodel is neither a Sequential, a Model nor a Layer.
        """
        if isinstance(kmodel, Sequential):
            for layer_config in kmodel.get_config():
                layer_name = layer_config["config"]["name"]
                node_id_to_config_layer[layer_name] = layer_config
        elif isinstance(kmodel, Model):
            for layer_config in kmodel.get_config()["layers"]:
                node_id_to_config_layer[layer_config["name"]] = layer_config
        elif isinstance(kmodel, Layer):
            node_id_to_config_layer[kmodel.name] = kmodel.get_config()
        else:
            raise Exception("should not enter here: %s" % kmodel)

    def __init__(self, kmodel):
        """
        :param kmodel: keras Sequential, Model or Layer to be converted
        """
        # layer name -> BigDL node already created for it (graph models only)
        self.node_id_to_instance = {}
        # layer name -> keras layer, nested layers included
        self.node_id_to_layer = {}
        # layer name -> keras layer config
        self.node_id_to_config_layer = {}
        self.kmodel = kmodel
        self.kconfig = self.kmodel.get_config()

        DefinitionLoader.__build_node_id_2_klayer(kmodel, self.node_id_to_layer)
        DefinitionLoader.__build_node_id_2_kclayer(kmodel, self.node_id_to_config_layer)

    def __to_bigdl(self):
        """Convert self.kmodel according to its type and return the BigDL result."""
        if isinstance(self.kmodel, Sequential):
            bigdlmodel = self._construct_bigdl_sequence()
        elif isinstance(self.kmodel, Model):
            bigdlmodel = self._construct_bigdl_model()
        elif isinstance(self.kmodel, Layer):
            bigdlmodel = LayerConverter(self.kmodel,
                                        self.node_id_to_config_layer[self.kmodel.name]).create()
        else:
            raise Exception("Should not enter here: %s" % self.kmodel)
        return bigdlmodel

    @classmethod
    def from_kmodel(cls, kmodel):
        """
        :param kmodel: keras Sequential, Model or Layer
        :return: BigDL Model
        """
        return cls(kmodel).__to_bigdl()

    @classmethod
    def from_hdf5_path(cls, hdf5_path):
        """
        :param hdf5_path: hdf5 path which can be stored in a local file system, HDFS, S3, or any Hadoop-supported file system.
        :return: a tuple (keras model, BigDL Model)
        """
        from keras.models import load_model
        hdf5_local_path = BCommon.get_local_file(hdf5_path)
        kmodel = load_model(hdf5_local_path)
        return kmodel, cls.from_kmodel(kmodel)

    @classmethod
    def from_json_path(cls, json_path):
        """
        :param json_path: definition path which can be stored in a local file system, HDFS, S3, or any Hadoop-supported file system.
        :return: BigDL Model
        """
        json_str = BCommon.text_from_path(json_path)
        return cls.from_json_str(json_str)

    @classmethod
    def from_json_str(cls, json_str):
        """
        :param json_str: keras model definition as a json string
        :return: BigDL Model
        """
        kmodel = model_from_json(json_str)
        return cls.from_kmodel(kmodel)

    def _do_create_node(self, layer, clayer):
        """
        Create the BigDL node for `layer`, creating the nodes it depends on
        first, and register it in self.node_id_to_instance.
        :param layer: keras layer
        :param clayer: config of that layer; "class_name" and "inbound_nodes" are read
        :return: the new BigDL node
        """
        if clayer["class_name"] == "InputLayer":
            input_node = BLayer.Input()
            input_node.element().set_name(layer.name)  # cannot set name for node?
            self.node_id_to_instance[layer.name] = input_node
            return input_node
        bigdl_in_nodes = []
        for node in clayer["inbound_nodes"]:
            for out in node:
                # Only the first entry (the layer name) of each inbound entry is used.
                out_name = out[0]
                if out_name not in self.node_id_to_instance:
                    self._do_create_node(self.node_id_to_layer[out_name],
                                         self.node_id_to_config_layer[out_name])
                bigdl_in_nodes.append(self.node_id_to_instance[out_name])

        blayer = LayerConverter(layer, clayer).create()
        new_bnode = blayer(bigdl_in_nodes)
        self.node_id_to_instance[layer.name] = new_bnode
        return new_bnode

    def _construct_bigdl_model(self):
        """Create one BigDL node per configured layer and wire them into a BLayer.Model."""
        for clayer in self.kconfig["layers"]:
            # A node may already exist because it was created as a dependency.
            if clayer["name"] not in self.node_id_to_instance:
                self._do_create_node(self.node_id_to_layer[clayer["name"]],
                                     clayer)
        ins = [self.node_id_to_instance[input_layer[0]]
               for input_layer in self.kconfig["input_layers"]]
        outs = [self.node_id_to_instance[output_layer[0]]
                for output_layer in self.kconfig["output_layers"]]
        return BLayer.Model(inputs=ins, outputs=outs)

    def _construct_bigdl_sequence(self):
        """Convert each layer of the keras Sequential, in order, into a BLayer.Sequential."""
        bseq = BLayer.Sequential()
        for layer in self.kmodel.layers:
            # recursive logic is within create method.
            blayer = LayerConverter(layer, self.node_id_to_config_layer[layer.name]).create()
            bseq.add(blayer)
        return bseq

class LayerConverter:

    def __init__(self, klayer, kclayer, input_shape=None):
        """
        :param klayer: keras layer to convert
        :param kclayer: config entry of that layer; its "config" item, when
               present, becomes self.config (otherwise an empty dict)
        :param input_shape: input shape to use; when empty or None the shape
               is taken from klayer.get_input_shape_at(0)
        """
        self.klayer = klayer
        self.kclayer = kclayer
        self.config = kclayer["config"] if "config" in kclayer else {}
        self.input_shape = input_shape if input_shape else klayer.get_input_shape_at(0)

    def __check_is_share_weights(self):
        """
        Raise if the layer config lists more than one inbound node, i.e. the
        layer is applied several times with shared weights.
        """
        # For Merge layer len(kclayer["inbound_nodes"]) is equal to 1: all of its
        # inputs are listed inside one single inbound node, e.g.
        # "inbound_nodes": [
        #                      [
        #                          ["batchnormalization_194", 0, 0],
        #                          ["batchnormalization_196", 0, 0],
        #                          ["batchnormalization_199", 0, 0],
        #                          ["batchnormalization_200", 0, 0]
        #                      ]
        #                  ],
        if "inbound_nodes" not in self.kclayer:
            return
        inbound_nodes = self.kclayer["inbound_nodes"]
        if len(inbound_nodes) > 1:
            raise Exception(
                "%s doesn't support multiple inputs with shared weights" % self.kclayer["class_name"])

    def create(self):
        """
        Convert self.klayer by dispatching to the `create_<lowercased class name>`
        method of this converter and name the result after the keras layer.
        Raises Exception for shared weights, constraints, activity regularizers
        and layer classes without a matching create method.
        """
        klayer = self.klayer
        self.__check_is_share_weights()

        for constraint_attr in ("b_constraint", "W_constraint"):
            if hasattr(klayer, constraint_attr) and getattr(klayer, constraint_attr):
                raise Exception("We don't support constraint for now")

        if hasattr(klayer, "activity_regularizer") and klayer.activity_regularizer:
            raise Exception("We don't support activity_regularizer for now")

        class_name = klayer.__class__.__name__
        blayer_creator = getattr(self, "create_" + class_name.lower(), None)
        if blayer_creator is None:
            raise Exception("We don't support layer: %s for now" % class_name )

        return blayer_creator().set_name(klayer.name)

    def create_model(self):
        """Convert a nested keras Model by delegating to DefinitionLoader.from_kmodel."""
        nested_kmodel = self.klayer
        return DefinitionLoader.from_kmodel(nested_kmodel)

    def create_sequential(self):
        """Convert a nested keras Sequential by delegating to DefinitionLoader.from_kmodel."""
        nested_kmodel = self.klayer
        return DefinitionLoader.from_kmodel(nested_kmodel)

    def create_inputlayer(self):
        """An InputLayer is converted to a BigDL Identity layer."""
        identity = BLayer.Identity()
        return identity

    def create_dense(self):
        """
        Convert a Dense layer to a BigDL Linear. When the input has more than
        two dimensions the Linear is wrapped between two InferReshape layers.
        """
        # Multiple inputs should share the same input_dim for Dense layer
        # We don't need to respect the tensor index for method `get_input_shape_at`
        # which is internal implementation and `get_input_shape_at` has hided that for us,
        # What we need to use is the input index, not node index, not tensor index.
        config = self.config
        out_dim = config["output_dim"]
        in_dim = int(self.input_shape[-1])
        linear = BLayer.Linear(
            input_size=in_dim,
            output_size=out_dim,
            with_bias=config["bias"],
            wRegularizer=to_bigdl_reg(config["W_regularizer"]),
            bRegularizer=to_bigdl_reg(config["b_regularizer"])
        )

        if len(self.input_shape) <= 2:
            core = linear
        else:
            # Collapse the leading dimensions to rows of in_dim, apply the Linear,
            # then restore the middle dimensions with out_dim as the last one.
            middle_dims = list(self.input_shape[1:-1])
            core = BLayer.Sequential()
            core.add(BLayer.InferReshape([-1, in_dim], False))
            core.add(linear)
            core.add(BLayer.InferReshape([-1] + middle_dims + [out_dim], False))
        return self.combo_parameter_layer(core, self.config)

    def create_timedistributeddense(self):
        """Convert a TimeDistributedDense layer to a BigDL TimeDistributed(Linear)."""
        config = self.config
        linear = BLayer.Linear(
            input_size=int(self.input_shape[-1]),
            output_size=config["output_dim"],
            with_bias=config["bias"],
            wRegularizer=to_bigdl_reg(config["W_regularizer"]),
            bRegularizer=to_bigdl_reg(config["b_regularizer"])
        )
        time_distributed = BLayer.TimeDistributed(linear)
        return self.combo_parameter_layer(time_distributed, self.config)

    def create_timedistributed(self):
        """Convert the wrapped layer and wrap the result in a BigDL TimeDistributed."""
        # input_shape is (batch, time, other dims); the wrapped layer sees it without the time dim
        batch_dim = self.input_shape[0]
        inner_input_shape = (batch_dim, ) + self.input_shape[2:]
        inner_converter = LayerConverter(self.klayer.layer, self.config['layer'], inner_input_shape)
        return BLayer.TimeDistributed(inner_converter.create())

    def create_bidirectional(self):
        """
        Convert a Bidirectional wrapper to a BigDL BiRecurrent whose merge layer
        is chosen from merge_mode ("concat", "sum", "mul" or "ave").
        The recurrent cell comes from this converter's
        `generate_<lowercased wrapped class name>_cell` method.
        """
        wrapped = self.klayer.layer
        if not wrapped.return_sequences:
            raise Exception("Only return_sequences=True is supported for RNNs for now")

        cell_factory = getattr(self, "generate_" + wrapped.__class__.__name__.lower() + "_cell")
        recurrent = cell_factory(wrapped, self.config['layer'], self.input_shape)

        last_dim = len(self.input_shape) - 1
        # Factories are lazy so that only the selected merge layer is instantiated.
        merge_factories = {
            "concat": lambda: BLayer.JoinTable(last_dim, last_dim),
            "sum": lambda: BLayer.CAddTable(),
            "mul": lambda: BLayer.CMulTable(),
            "ave": lambda: BLayer.CAveTable(),
        }
        merge_mode = self.klayer.merge_mode
        if merge_mode not in merge_factories:
            raise Exception("Invalid merge mode: %s" % self.klayer.merge_mode)
        merge = merge_factories[merge_mode]()
        return BLayer.BiRecurrent(merge).add(recurrent)

    def create_embedding(self):
        """
        Convert an Embedding layer to Sequential(AddConstant(1.0), LookupTable).
        Raises Exception when input_length disagrees with the input shape, or
        when dropout or mask_zero is set.
        """
        klayer = self.klayer
        seq_len = int(self.input_shape[1])
        if klayer.input_length and klayer.input_length != seq_len:
            raise Exception(
                "The input_length doesn't match: %s vs %s" % (seq_len, klayer.input_length))

        if getattr(klayer, "dropout", 0) != 0:
            raise Exception("We don't support dropout for now")

        if getattr(klayer, "mask_zero", False) != False:
            raise Exception("We don't support mask_zero for now")

        bseq = BLayer.Sequential()
        lookup = BLayer.LookupTable(
            n_index=klayer.input_dim,
            n_output=klayer.output_dim,
            padding_value=0.0,
            norm_type=2.0,
            should_scale_grad_by_freq=False,
            wRegularizer=to_bigdl_reg(self.config["W_regularizer"]),
            bigdl_type="float")
        # Add 1 as BigDL is one-based index
        bseq.add(BLayer.AddConstant(1.0))
        bseq.add(lookup)
        lookup.set_init_method(to_bigdl_init(self.config["init"]))
        return bseq

    def create_activation(self):
        """
        Convert an Activation layer through get_activation_by_name.
        A softmax over a 3D input is wrapped between two Transpose([(1, 3)]) layers.
        """
        activation = self.config["activation"]
        blayer = get_activation_by_name(activation, self.klayer.name)

        # SoftMax is different between Keras and BigDL for 3D inputs
        is_3d_softmax = activation == "softmax" and len(self.input_shape) == 3
        if not is_3d_softmax:
            return blayer

        model = BLayer.Sequential()
        for step in (BLayer.Transpose([(1, 3)]), blayer, BLayer.Transpose([(1, 3)])):
            model.add(step)
        return model

    def create_dropout(self):
        """Convert a Dropout layer, passing the keras `p` attribute to BigDL Dropout."""
        drop_probability = self.klayer.p
        return BLayer.Dropout(drop_probability)

    def create_flatten(self):
        """Convert Flatten to a Reshape whose size is the product of all non-batch dims."""
        non_batch_dims = self.input_shape[1:]
        flat_size = int(np.prod(non_batch_dims))
        return BLayer.Reshape([flat_size], None)

    def create_permute(self):
        """
        Convert Permute to a BigDL Transpose: the swaps found by __perm_to_pair
        are applied in reverse order, each index shifted by one.
        """
        pairs = self.__perm_to_pair(list(self.klayer.dims))
        shifted = [(first + 1, second + 1) for first, second in reversed(pairs)]
        return BLayer.Transpose(shifted)

    def __perm_to_pair(self, perm):
        """
        Sort `perm` in place with a quicksort and record every exchange.
        :param perm: a list as a permutation of [1..n], eg [3, 1, 2] for n=3.
        :return: a list of 1-based index tuples that needs to be swapped to
                 obtain the input `perm`; exchanges of an element with itself
                 are left out.
        """
        swaps = []

        def swap(seq, a, b):
            seq[a], seq[b] = seq[b], seq[a]
            swaps.append((a + 1, b + 1))

        def quicksort(seq, lo, hi):
            left, right = lo, hi
            pivot = seq[(lo + hi) // 2]
            while left <= right:
                while seq[left] < pivot:
                    left += 1
                while seq[right] > pivot:
                    right -= 1
                if left <= right:
                    swap(seq, left, right)
                    left += 1
                    right -= 1
            if lo < right:
                quicksort(seq, lo, right)
            if left < hi:
                quicksort(seq, left, hi)

        quicksort(perm, 0, len(perm) - 1)

        return [pair for pair in swaps if pair[0] != pair[1]]

    def create_reshape(self):
        """Convert Reshape; InferReshape is used when the target shape contains -1."""
        target_shape = self.klayer.target_shape
        if -1 not in target_shape:
            return BLayer.Reshape(target_shape, None)
        return BLayer.InferReshape(target_shape, True)

    def create_repeatvector(self):
        """Convert RepeatVector to a BigDL Replicate with the keras `n` as n_features."""
        repeat_count = self.klayer.n
        return BLayer.Replicate(n_features=repeat_count, n_dim=1, bigdl_type="float")

    def __is_from_sequential(self):
        """
        Tell whether the layer carries its own sub-models: the config has a
        "layers" entry and the keras layer has a `layers` attribute that is not None.
        """
        if "layers" not in self.kclayer["config"]:
            return False
        if not hasattr(self.klayer, "layers"):
            return False
        return self.klayer.layers is not None

    def create_merge(self):
        """
        Convert a Merge layer. Supported modes are "concat", "sum", "mul",
        "max", "dot", "ave" and "cos"; anything else raises an Exception.
        When the Merge layer carries its own sub-models, the result is a
        Sequential made of a ParallelTable of the converted sub-models followed
        by the merge operation.
        """
        klayer = self.klayer
        if klayer.output_shape and not isinstance(klayer.output_shape, tuple):
            raise Exception("Only output_shape=None or a shape tuple is supported for now")
        if klayer.node_indices and any(0 != i for i in klayer.node_indices):
            unsupport_exp("node_indices")
        if klayer.output_mask:
            unsupport_exp("output_mask")

        mode = klayer.mode
        if mode == "concat":
            blayer = BLayer.JoinTable(
                dimension=klayer.concat_axis,
                n_input_dims=len(self.input_shape[0]) - 1,
                bigdl_type="float")
        elif mode == "sum":
            blayer = BLayer.CAddTable(inplace=False, bigdl_type="float")
        elif mode == "mul":
            blayer = BLayer.CMulTable(bigdl_type="float")
        elif mode == "max":
            blayer = BLayer.CMaxTable(bigdl_type="float")
        elif mode == "dot":
            if len(self.input_shape[0]) >= 3:
                raise Exception("For merge mode dot, 3D input or above is not supported for now.")
            if klayer.dot_axes != [1, 1]:
                raise Exception("For merge mode dot, only dot_axes=1 is supported for now.")
            blayer = BLayer.Sequential() \
                .add(BLayer.DotProduct(bigdl_type="float")) \
                .add(BLayer.Reshape([1], True))
        elif mode == "ave":
            blayer = BLayer.CAveTable(inplace=False, bigdl_type="float")
        elif mode == 'cos':
            if len(self.input_shape[0]) >= 3:
                raise Exception("For merge mode cos, 3D input or above is not supported for now.")
            if klayer.dot_axes != [1, 1]:
                raise Exception("For merge mode cos, only dot_axes=1 is supported for now.")
            blayer = BLayer.Sequential()
            blayer.add(BLayer.CosineDistance(bigdl_type="float")).add(BLayer.Reshape([1, 1], True))
        else:  # invalid mode or lambda functions
            raise Exception("Invalid merge mode: `%s`. Lambda/function as merge mode is not supported for now."
                            % mode)

        if not self.__is_from_sequential():
            return blayer

        # Each sub-model becomes one branch; the merge operation consumes their outputs.
        bseq = BLayer.Sequential()
        branches = BLayer.ParallelTable()
        for sub_model in klayer.layers:
            branches.add(DefinitionLoader.from_kmodel(sub_model))
        bseq.add(branches)
        bseq.add(blayer)
        return bseq

    def create_elu(self):
        """Convert ELU, passing the keras alpha as a float."""
        alpha = float(self.klayer.alpha)
        return BLayer.ELU(alpha=alpha, inplace=False, bigdl_type="float")

    def create_prelu(self):
        """Convert PReLU to a BigDL PReLU created with n_output_plane=0."""
        prelu = BLayer.PReLU(n_output_plane=0, bigdl_type="float")
        return prelu

    def create_leakyrelu(self):
        """Convert LeakyReLU, passing the keras alpha as the BigDL negval."""
        negative_slope = float(self.klayer.alpha)
        return BLayer.LeakyReLU(negval=negative_slope, inplace=False, bigdl_type="float")

    def create_parametricsoftplus(self):
        """
        Convert ParametricSoftplus to a BigDL SoftPlus with beta=beta_init.
        Only shared_axes == [None] and alpha_init * beta_init == 1 (rounded to
        4 decimals) are accepted; anything else raises an Exception.
        """
        klayer = self.klayer
        alpha = float(klayer.alpha_init)
        beta = float(klayer.beta_init)
        if klayer.shared_axes != [None]:
            unsupport_exp("shared_axes")
        if round(alpha * beta, 4) != 1.0:
            raise Exception("Only alpha_init = 1/beta_init is supported for now")
        return BLayer.SoftPlus(beta=beta, bigdl_type="float")

    def create_thresholdedrelu(self):
        """Convert ThresholdedReLU to a BigDL Threshold with th=theta and v=0.0."""
        theta = float(self.klayer.theta)
        return BLayer.Threshold(th=theta, v=0.0, ip=False, bigdl_type="float")

    def __generate_zeropadding1d(self, pad_top, pad_bottom):
        """Build a SpatialZeroPadding that pads only top/bottom; left/right stay 0."""
        padding_args = dict(pad_left=0,
                            pad_right=0,
                            pad_top=pad_top,
                            pad_bottom=pad_bottom,
                            bigdl_type="float")
        return BLayer.SpatialZeroPadding(**padding_args)

    def create_zeropadding1d(self):
        """
        Convert ZeroPadding1D. `padding` may be an int (same pad on both sides),
        a dict with 'left_pad'/'right_pad' (missing keys mean 0) or a sequence
        whose first two items are the two pads.
        """
        padding = self.klayer.padding
        if isinstance(padding, int):
            pad_top = pad_bottom = padding
        elif isinstance(padding, dict):
            pad_top = padding.get('left_pad', 0)
            pad_bottom = padding.get('right_pad', 0)
        else:  # tuple of int (length 2)
            as_tuple = tuple(padding)
            pad_top, pad_bottom = as_tuple[0], as_tuple[1]
        return self.__generate_zeropadding1d(pad_top, pad_bottom)

    def __generate_zeropadding2d(self, dim1, dim2, n_input_dim, pad1, pad2, pad3, pad4):
        """
        Build a Sequential of four BigDL Padding layers with value 0.0:
        pad1 and pad2 are applied on dim1, then pad3 and pad4 on dim2.
        """
        model = BLayer.Sequential()
        dim_and_pad = ((dim1, pad1), (dim1, pad2), (dim2, pad3), (dim2, pad4))
        for dim, pad in dim_and_pad:
            model.add(BLayer.Padding(dim=dim,
                                     pad=pad,
                                     n_input_dim=n_input_dim,
                                     value=0.0,
                                     n_index=1,
                                     bigdl_type="float"))
        return model

    # NB: zeropadding doesn't serialize dim_ordering to json file
    def create_zeropadding2d(self):
        """Convert the Keras 2D zero-padding layer into a chain of BigDL Padding layers.

        `self.klayer.padding` may be:
          * a dict with optional 'top_pad', 'bottom_pad', 'left_pad' and
            'right_pad' keys (missing keys count as 0);
          * a sequence of 2 ints: the same amount on both sides of each of the
            two padded dims;
          * a sequence of 4 ints: one amount per side, in the order
            (first dim before, first dim after, second dim before, second dim after).

        The first padded dim is 2 when `self.klayer.dim_ordering` is "th" and 1
        otherwise; the second padded dim is the one right after it.

        :raises Exception: if the padding sequence has a length other than 2 or 4
                           (previously this case silently returned None).
        """
        padding = self.klayer.padding
        if "dim_ordering" not in self.config:
            warnings.warn("Cannot find dim_ordering from json definition. Using the default instead.")
        dim = 2 if self.klayer.dim_ordering == "th" else 1

        if isinstance(padding, dict):  # dictionary
            pads = (padding.get('top_pad', 0), padding.get('bottom_pad', 0),
                    padding.get('left_pad', 0), padding.get('right_pad', 0))
        else:  # tuple of int
            padding = tuple(padding)
            if len(padding) == 2:
                pads = (padding[0], padding[0], padding[1], padding[1])
            elif len(padding) == 4:
                pads = padding
            else:
                raise Exception("Unsupported padding: %s. Expected a dict or a tuple of "
                                "int of length 2 or 4." % (padding,))

        # The "before" amounts are passed negated and the "after" amounts as-is.
        # NOTE(review): this relies on BLayer.Padding treating a negative pad as
        # padding in front of the data - confirm against the BigDL Padding docs.
        return self.__generate_zeropadding2d(dim, dim + 1, len(self.input_shape) - 1,
                                             -pads[0], pads[1], -pads[2], pads[3])

    # NB: zeropadding doesn't serialize dim_ordering to json file
    def create_zeropadding3d(self):
        """Convert the Keras 3D zero-padding layer into a chain of BigDL Padding layers.

        For each of the three padded dims, two Padding layers are added: one
        with the negated amount and one with the amount itself, so
        `self.klayer.padding[i]` is applied symmetrically on dim `first_dim + i`.
        `first_dim` is 2 when `self.klayer.dim_ordering` is "th" and 1 otherwise.

        :return: a BLayer.Sequential holding the six Padding layers.
        """
        padding = tuple(self.klayer.padding)
        if "dim_ordering" not in self.config:
            warnings.warn("Cannot find dim_ordering from json definition. Using the default instead.")
        first_dim = 2 if self.klayer.dim_ordering == "th" else 1

        container = BLayer.Sequential()
        padding_layers = []
        for offset in range(3):
            # Negated amount first, then the positive one, on the same dim.
            for sign in (-1, 1):
                padding_layers.append(BLayer.Padding(dim=first_dim + offset,
                                                     pad=sign * padding[offset],
                                                     n_input_dim=len(self.input_shape) - 1,
                                                     value=0.0,
                                                     n_index=1,
                                                     bigdl_type="float"))
        for padding_layer in padding_layers:
            container.add(padding_layer)
        return container

    def create_cropping1d(self):
        """Convert the Keras 1D cropping layer into a BigDL SpatialZeroPadding.

        Cropping is expressed as negative padding: the two cropping amounts are
        negated and passed as the third and fourth positional arguments, while
        the first two are 0.
        """
        crop = tuple(self.klayer.cropping)
        crop_begin = -crop[0]
        crop_end = -crop[1]
        return BLayer.SpatialZeroPadding(0, 0, crop_begin, crop_end)

    def create_cropping2d(self):
        """Convert the Keras 2D cropping layer into a BigDL Cropping2D.

        `cropping[0]` becomes `heightCrop`, `cropping[1]` becomes `widthCrop`,
        and the data format comes from `get_bdim_order()`.
        """
        data_format = self.get_bdim_order()
        cropping = self.klayer.cropping
        return BLayer.Cropping2D(heightCrop=cropping[0],
                                 widthCrop=cropping[1],
                                 data_format=data_format)

    def create_cropping3d(self):
        """Convert the Keras 3D cropping layer into a BigDL Cropping3D.

        `cropping[0..2]` become `dim1Crop`, `dim2Crop` and `dim3Crop`; the data
        format comes from `get_bdim_order("3D")`.
        """
        data_format = self.get_bdim_order("3D")
        cropping = self.klayer.cropping
        return BLayer.Cropping3D(dim1Crop=cropping[0],
                                 dim2Crop=cropping[1],
                                 dim3Crop=cropping[2],
                                 data_format=data_format)

    def __check_recurrent_parameters(self, klayer):
        """Reject recurrent-layer settings that cannot be converted.

        :param klayer: the Keras recurrent layer to validate.
        :raises Exception: if `klayer.stateful` is truthy, or if `klayer` has a
                           `consume_less` attribute equal to "gpu".
        """
        if klayer.stateful:
            raise Exception("Only stateful=False for recurrent layers is supported for now")
        # consume_less is optional on the layer, hence the default.
        if getattr(klayer, "consume_less", None) == "gpu":
            raise Exception("consume_less=gpu is not supported for now")

    def __process_recurrent_layer(self, return_sequences, go_backwards, blayer):
        """Wrap a recurrent BigDL layer with optional reversing and selection.

        :param return_sequences: when falsy, a `Select(2, -1)` is appended after
                                 `blayer` (the original comment describes this as
                                 returning only the last output of the sequence).
        :param go_backwards: when truthy, a `Reverse(2)` is placed before `blayer`.
        :param blayer: the recurrent BigDL layer to wrap.
        :return: a BLayer.Sequential containing the assembled pipeline.
        """
        pipeline = BLayer.Sequential()
        if go_backwards:
            pipeline.add(BLayer.Reverse(2))
        pipeline.add(blayer)
        if return_sequences:
            return pipeline
        pipeline.add(BLayer.Select(2, -1))
        return pipeline

    def generate_simplernn_cell(self, klayer, kclayer, input_shape):
        """Create a BigDL RnnCell only (not wrapped in a Recurrent container).

        :param klayer: Keras layer; validated first, and supplies `output_dim`
                       as the hidden size.
        :param kclayer: layer definition dict; its "config" entry supplies the
                        name, activation and W/U/b regularizers.
        :param input_shape: input shape; element 2 is used as the input size.
        :return: a float-typed BLayer.RnnCell.
        """
        self.__check_recurrent_parameters(klayer)
        layer_config = kclayer["config"]
        activation_name = layer_config["activation"]
        activation = get_activation_by_name(activation_name,
                                            "%s_%s" % (layer_config["name"], activation_name))
        return BLayer.RnnCell(input_size=int(input_shape[2]),
                              hidden_size=klayer.output_dim,
                              activation=activation,
                              isInputWithBias=False,
                              wRegularizer=to_bigdl_reg(layer_config["W_regularizer"]),
                              uRegularizer=to_bigdl_reg(layer_config["U_regularizer"]),
                              bRegularizer=to_bigdl_reg(layer_config["b_regularizer"]),
                              bigdl_type="float")

    def create_simplernn(self):
        """Convert the Keras SimpleRNN layer: an RnnCell inside a BigDL Recurrent.

        The result is post-processed according to the Keras layer's
        `return_sequences` and `go_backwards` settings.
        """
        recurrent = BLayer.Recurrent()
        cell = self.generate_simplernn_cell(self.klayer, self.kclayer, self.input_shape)
        return_sequences = self.klayer.return_sequences
        go_backwards = self.klayer.go_backwards
        return self.__process_recurrent_layer(return_sequences, go_backwards,
                                              recurrent.add(cell))

    def generate_lstm_cell(self, klayer, kclayer, input_shape):
        """Create a BigDL LSTM cell only (not wrapped in a Recurrent container).

        :param klayer: Keras layer; validated first, and supplies `output_dim`
                       as the hidden size.
        :param kclayer: layer definition dict; its "config" entry supplies the
                        name, both activations and the W/U/b regularizers.
        :param input_shape: input shape; element 2 is used as the input size.
        :return: a float-typed BLayer.LSTM created with p=0.0.
        """
        self.__check_recurrent_parameters(klayer)
        layer_config = kclayer["config"]
        activation_name = layer_config["activation"]
        activation = get_activation_by_name(activation_name,
                                            "%s_%s" % (layer_config["name"], activation_name))
        inner_name = layer_config["inner_activation"]
        inner_activation = get_activation_by_name(inner_name,
                                                  "%s_%s" % (layer_config["name"], inner_name))
        return BLayer.LSTM(input_size=int(input_shape[2]),
                           hidden_size=klayer.output_dim,
                           p=0.0,
                           activation=activation,
                           inner_activation=inner_activation,
                           wRegularizer=to_bigdl_reg(layer_config["W_regularizer"]),
                           uRegularizer=to_bigdl_reg(layer_config["U_regularizer"]),
                           bRegularizer=to_bigdl_reg(layer_config["b_regularizer"]),
                           bigdl_type="float")

    def create_lstm(self):
        """Convert the Keras LSTM layer: an LSTM cell inside a BigDL Recurrent.

        The result is post-processed according to the Keras layer's
        `return_sequences` and `go_backwards` settings.
        """
        recurrent = BLayer.Recurrent()
        cell = self.generate_lstm_cell(self.klayer, self.kclayer, self.input_shape)
        return_sequences = self.klayer.return_sequences
        go_backwards = self.klayer.go_backwards
        return self.__process_recurrent_layer(return_sequences, go_backwards,
                                              recurrent.add(cell))

    def generate_convlstm2d_cell(self, klayer, kclayer, input_shape):
        """Create a BigDL ConvLSTMPeephole cell only (peephole disabled).

        :param klayer: Keras layer; validated first, and supplies `subsample`
                       (only its first element is used, as the stride).
        :param kclayer: layer definition dict; its "config" entry supplies the
                        name, both activations, nb_filter, nb_col and nb_row.
        :param input_shape: input shape; element 2 is used as the input size.
        :return: a float-typed BLayer.ConvLSTMPeephole with padding=-1.
        """
        self.__check_recurrent_parameters(klayer)
        layer_config = kclayer["config"]
        activation_name = layer_config["activation"]
        activation = get_activation_by_name(activation_name,
                                            "%s_%s" % (layer_config["name"], activation_name))
        inner_name = layer_config["inner_activation"]
        inner_activation = get_activation_by_name(inner_name,
                                                  "%s_%s" % (layer_config["name"], inner_name))

        return BLayer.ConvLSTMPeephole(input_size=int(input_shape[2]),
                                       output_size=layer_config["nb_filter"],
                                       kernel_i=layer_config["nb_col"],
                                       kernel_c=layer_config["nb_row"],
                                       # NB: ConvLSTM doesn't serialize subsample to json file
                                       stride=klayer.subsample[0],
                                       padding=-1,
                                       activation=activation,
                                       inner_activation=inner_activation,
                                       # NB: ConvLSTM doesn't serialize regularizers to json file,
                                       # so the W/U/b regularizers are not forwarded here.
                                       cRegularizer=None,
                                       with_peephole=False,
                                       bigdl_type="float")

    def create_convlstm2d(self):
        """Convert the Keras ConvLSTM2D layer: a ConvLSTM cell inside a BigDL Recurrent.

        :raises Exception: unless border_mode is 'same', dim_ordering is "th",
                           the kernel is square (nb_row == nb_col) and both
                           subsample strides are equal.
        """
        # TODO: border_mode = 'valid'
        if self.config["border_mode"] != 'same':
            raise Exception("Unsupported border_mode: valid")

        dim_ordering = self.klayer.dim_ordering
        if dim_ordering != "th":
            raise Exception("Please use `th` for `dim_ordering`. `%s` is not supported for now."
                            % dim_ordering)

        kernel_rows, kernel_cols = self.config["nb_row"], self.config["nb_col"]
        if kernel_rows != kernel_cols:
            raise Exception("Only square kernel is supported for now. Please set nb_row=nb_col.")

        subsample = self.klayer.subsample
        if subsample[0] != subsample[1]:
            raise Exception("Only equal stride is supported for now. "
                            "Please set subsample to be a tuple with equal values.")

        recurrent = BLayer.Recurrent()
        cell = self.generate_convlstm2d_cell(self.klayer,
                                             self.kclayer, self.input_shape)
        return self.__process_recurrent_layer(self.klayer.return_sequences,
                                              self.klayer.go_backwards,
                                              recurrent.add(cell))

    def generate_gru_cell(self, klayer, kclayer, input_shape):
        """Create a BigDL GRU cell only (not wrapped in a Recurrent container).

        :param klayer: Keras layer; validated first, and supplies `output_dim`
                       as the hidden size.
        :param kclayer: layer definition dict; its "config" entry supplies the
                        name, both activations and the W/U/b regularizers.
        :param input_shape: input shape; element 2 is used as the input size.
        :return: a float-typed BLayer.GRU created with p=0.0.
        """
        self.__check_recurrent_parameters(klayer)
        layer_config = kclayer["config"]
        activation_name = layer_config["activation"]
        activation = get_activation_by_name(activation_name,
                                            "%s_%s" % (layer_config["name"], activation_name))
        inner_name = layer_config["inner_activation"]
        inner_activation = get_activation_by_name(inner_name,
                                                  "%s_%s" % (layer_config["name"], inner_name))
        return BLayer.GRU(input_size=int(input_shape[2]),
                          hidden_size=klayer.output_dim,
                          p=0.0,
                          activation=activation,
                          inner_activation=inner_activation,
                          wRegularizer=to_bigdl_reg(layer_config["W_regularizer"]),
                          uRegularizer=to_bigdl_reg(layer_config["U_regularizer"]),
                          bRegularizer=to_bigdl_reg(layer_config["b_regularizer"]),
                          bigdl_type="float")

    def create_gru(self):
        """Convert the Keras GRU layer: a GRU cell inside a BigDL Recurrent.

        The result is post-processed according to the Keras layer's
        `return_sequences` and `go_backwards` settings.
        """
        recurrent = BLayer.Recurrent()
        cell = self.generate_gru_cell(self.klayer, self.kclayer, self.input_shape)
        return_sequences = self.klayer.return_sequences
        go_backwards = self.klayer.go_backwards
        return self.__process_recurrent_layer(return_sequences, go_backwards,
                                              recurrent.add(cell))

    def create_batchnormalization(self):
        """Convert the Keras BatchNormalization layer into a BigDL SpatialBatchNormalization.

        Gamma and beta are initialised from the Keras layer's `gamma_init` and
        `beta_init`, and the Keras running mean and running std are copied onto
        the new BigDL layer.

        :raises Exception: if the input is not 4D; if the layer axis does not
                           match the backend image ordering (axis=1 for "th",
                           axis=-1 for "tf"); if mode is not 0; or if a gamma
                           or beta regularizer is configured.
        """
        n_input_dim = len(self.input_shape)
        if n_input_dim != 4:
            # The value must be interpolated with `%`; passing it as a second
            # argument to Exception leaves the "%s" placeholder unformatted.
            raise Exception("Only 4D input is supported for now, but the current input dim is %s"
                            % n_input_dim)
        image_ordering = keras.backend.image_dim_ordering()
        if image_ordering == "th" and self.klayer.axis != 1:
            raise Exception("""For BatchNormalization with th image ordering, we only support """ +
                            """axis = 1 for now, but the current axis is %s
                            """ % self.klayer.axis)  # noqa
        if image_ordering == "tf" and self.klayer.axis != -1:
            raise Exception("""For BatchNormalization with tf image ordering, we only support """ +
                            """axis = -1 for now, but the current axis is %s
                            """ % self.klayer.axis)
        if self.klayer.mode != 0:
            raise Exception(
                "Only support mode = 0 for now, but the current mode is: %s" % self.klayer.mode)

        if self.config["gamma_regularizer"]:
            raise Exception("We don't support gamma_regularizer for now")

        if self.config["beta_regularizer"]:
            raise Exception("We don't support beta_regularizer for now")

        bigdl_order = to_bigdl_2d_ordering(image_ordering)
        n_input_channel = int(self.input_shape[self.klayer.axis])

        # init gamma and beta
        # TODO: replace this with to_bigdl_init in the future
        gamma = self.get_value_from_init(self.klayer.gamma_init.__name__, (n_input_channel,))
        beta = self.get_value_from_init(self.klayer.beta_init.__name__, (n_input_channel,))

        blayer = BLayer.SpatialBatchNormalization(
                 n_output=n_input_channel,
                 eps=self.klayer.epsilon,
                 momentum=self.klayer.momentum,
                 affine=True,
                 init_weight=gamma,
                 init_bias=beta,
                 init_grad_weight=None,
                 init_grad_bias=None,
                 data_format=bigdl_order,
                 bigdl_type="float")

        # Carry the statistics accumulated by Keras over to the BigDL layer.
        k_running_mean = keras.backend.eval(self.klayer.running_mean)
        k_running_std = keras.backend.eval(self.klayer.running_std)
        blayer.set_running_mean(k_running_mean)
        blayer.set_running_std(k_running_std)
        return blayer

    def get_bdim_order(self, dim="2D"):
        """Get the BigDL dim ordering that corresponds to the Keras dim_ordering.

        The Keras ordering is read from `self.config["dim_ordering"]`; when that
        key is absent a warning is emitted and the Keras backend default is used.

        :param dim: "3D" selects the 3D ordering conversion; any other value
                    selects the 2D one.
        :return: the result of `to_bigdl_3d_ordering` or `to_bigdl_2d_ordering`.
        """
        if "dim_ordering" not in self.config:
            warnings.warn("Cannot find dim_ordering from json definition. Using the default instead.")
            keras_order = keras.backend.image_dim_ordering()
        else:
            keras_order = self.config["dim_ordering"]
        to_bigdl_ordering = to_bigdl_3d_ordering if dim == "3D" else to_bigdl_2d_ordering
        return to_bigdl_ordering(keras_order)

    def create_convolution1d(self):
        """Convert the Keras 1D convolution layer into a BigDL Sequential.

        The pipeline is: Reshape to [steps, 1, dim], an NHWC SpatialConvolution
        whose kernel and stride are 1 along the width, then Squeeze(3). The
        result is passed through `combo_parameter_layer` with the layer config.
        """
        # input_shape is (batch, steps, dim); batch is None here, so it cannot be used directly.
        steps = int(self.input_shape[1])
        input_dim = int(self.input_shape[2])

        pad_w, pad_h = to_bigdl_2d_padding(self.klayer.border_mode)
        container = BLayer.Sequential()
        container.add(BLayer.Reshape([steps, 1, input_dim], True))
        conv = BLayer.SpatialConvolution(
            n_input_plane=input_dim,
            n_output_plane=self.klayer.nb_filter,
            kernel_w=1,
            kernel_h=self.klayer.filter_length,
            stride_w=1,
            stride_h=self.klayer.subsample_length,
            pad_w=pad_w,
            pad_h=pad_h,
            n_group=1,
            propagate_back=True,
            wRegularizer=to_bigdl_reg(self.config["W_regularizer"]),
            bRegularizer=to_bigdl_reg(self.config["b_regularizer"]),
            init_weight=None,
            init_bias=None,
            init_grad_weight=None,
            init_grad_bias=None,
            with_bias=self.config["bias"],
            data_format="NHWC",
            bigdl_type="float")
        container.add(conv)
        container.add(BLayer.Squeeze(3))
        return self.combo_parameter_layer(container, self.config)

    def create_convolution2d(self):
        """Convert the Keras 2D convolution layer into a BigDL SpatialConvolution.

        The number of input planes is read from input_shape[1] for "NCHW" and
        from input_shape[3] for "NHWC". The result is passed through
        `combo_parameter_layer` with the layer config.
        """
        data_format = self.get_bdim_order()

        if data_format == "NCHW":
            n_input_plane = int(self.input_shape[1])
        elif data_format == "NHWC":
            n_input_plane = int(self.input_shape[3])

        pad_w, pad_h = to_bigdl_2d_padding(self.klayer.border_mode)
        subsample = self.klayer.subsample
        conv = BLayer.SpatialConvolution(
            n_input_plane=n_input_plane,
            n_output_plane=self.klayer.nb_filter,
            kernel_w=self.klayer.nb_col,
            kernel_h=self.klayer.nb_row,
            stride_w=subsample[1],
            stride_h=subsample[0],
            pad_w=pad_w,
            pad_h=pad_h,
            n_group=1,
            propagate_back=True,
            wRegularizer=to_bigdl_reg(self.config["W_regularizer"]),
            bRegularizer=to_bigdl_reg(self.config["b_regularizer"]),
            init_weight=None,
            init_bias=None,
            init_grad_weight=None,
            init_grad_bias=None,
            with_bias=self.config["bias"],
            data_format=data_format,
            bigdl_type="float")

        return self.combo_parameter_layer(conv, self.config)

    def create_convolution3d(self):
        """Convert the Keras 3D convolution layer into a BigDL VolumetricConvolution.

        Kernel dims 1/2/3 and subsample[0]/[1]/[2] are mapped to the BigDL
        t/h/w arguments respectively. The result is passed through
        `combo_parameter_layer` with the layer config.

        :raises Exception: if the Keras layer's dim_ordering is not "th".
        """
        dim_ordering = self.klayer.dim_ordering
        if dim_ordering != "th":
            raise Exception("Please use `th` for `dim_ordering`. `%s` is not supported for now."
                            % dim_ordering)

        pad_t, pad_w, pad_h = to_bigdl_3d_padding(self.klayer.border_mode)
        subsample = self.klayer.subsample
        conv = BLayer.VolumetricConvolution(
            n_input_plane=int(self.input_shape[1]),
            n_output_plane=self.klayer.nb_filter,
            k_t=self.klayer.kernel_dim1,
            k_w=self.klayer.kernel_dim3,
            k_h=self.klayer.kernel_dim2,
            d_t=subsample[0],
            d_w=subsample[2],
            d_h=subsample[1],
            pad_t=pad_t,
            pad_w=pad_w,
            pad_h=pad_h,
            with_bias=self.config["bias"],
            wRegularizer=to_bigdl_reg(self.config["W_regularizer"]),
            bRegularizer=to_bigdl_reg(self.config["b_regularizer"]),
            bigdl_type="float")

        return self.combo_parameter_layer(conv, self.config)

    def create_atrousconvolution1d(self):
        """Convert the Keras AtrousConvolution1D layer into a BigDL Sequential.

        The pipeline is: Transpose([(2, 3)]), Reshape to [dim, steps, 1], a
        SpatialDilatedConvolution that is only dilated/strided along the height,
        Transpose([(2, 3)]) again, then Squeeze(4). The result is passed through
        `combo_parameter_layer` with the layer config.

        :raises Exception: if the layer was configured with bias=False.
        """
        if not self.config["bias"]:
            raise Exception("Only bias=True is supported for AtrousConvolution1D")

        steps = int(self.input_shape[1])
        kernel_h = self.config["filter_length"]
        stride_h = self.config["subsample_length"]
        dilation_h = self.config["atrous_rate"]
        pad_h, pad_w = to_bigdl_2d_padding(self.config["border_mode"],
                                           steps, kernel_h, stride_h, dilation_h)
        input_dim = int(self.input_shape[2])

        container = BLayer.Sequential()
        container.add(BLayer.Transpose([(2, 3)]))
        container.add(BLayer.Reshape([input_dim, steps, 1], True))
        dilated_conv = BLayer.SpatialDilatedConvolution(
            n_input_plane=input_dim,
            n_output_plane=self.config["nb_filter"],
            kw=1,
            kh=kernel_h,
            dw=1,
            dh=stride_h,
            pad_w=pad_w,
            pad_h=pad_h,
            dilation_w=1,
            dilation_h=dilation_h,
            wRegularizer=to_bigdl_reg(self.config["W_regularizer"]),
            bRegularizer=to_bigdl_reg(self.config["b_regularizer"]),
            bigdl_type="float")

        container.add(dilated_conv)
        container.add(BLayer.Transpose([(2, 3)]))
        container.add(BLayer.Squeeze(4))
        return self.combo_parameter_layer(container, self.config)

    def create_atrousconvolution2d(self):
        """Convert the Keras AtrousConvolution2D layer into a BigDL SpatialDilatedConvolution.

        Height/width sizes are read from input_shape[2]/[3]; kernel, stride and
        dilation come from the config's nb_row/nb_col, subsample and atrous_rate.
        The result is passed through `combo_parameter_layer` with the layer config.

        :raises Exception: if dim_ordering is not "th" or bias is False.
        """
        dim_ordering = self.klayer.dim_ordering
        if dim_ordering != "th":
            raise Exception("Please use `th` for `dim_ordering`. `%s` is not supported for now."
                            % dim_ordering)
        if not self.config["bias"]:
            raise Exception("Only bias=True is supported for AtrousConvolution2D")

        in_h = int(self.input_shape[2])
        in_w = int(self.input_shape[3])
        kernel_h = self.config["nb_row"]
        kernel_w = self.config["nb_col"]
        stride_h = self.config["subsample"][0]
        stride_w = self.config["subsample"][1]
        dilation_h = self.config["atrous_rate"][0]
        dilation_w = self.config["atrous_rate"][1]
        pad_h, pad_w = to_bigdl_2d_padding(self.config["border_mode"],
                                           in_h, kernel_h, stride_h, dilation_h,
                                           in_w, kernel_w, stride_w, dilation_w)
        dilated_conv = BLayer.SpatialDilatedConvolution(
            n_input_plane=int(self.input_shape[1]),
            n_output_plane=self.config["nb_filter"],
            kw=kernel_w,
            kh=kernel_h,
            dw=stride_w,
            dh=stride_h,
            pad_w=pad_w,
            pad_h=pad_h,
            dilation_w=dilation_w,
            dilation_h=dilation_h,
            wRegularizer=to_bigdl_reg(self.config["W_regularizer"]),
            bRegularizer=to_bigdl_reg(self.config["b_regularizer"]),
            bigdl_type="float")

        return self.combo_parameter_layer(dilated_conv, self.config)

    def create_deconvolution2d(self):
        """Convert the Keras Deconvolution2D layer into a BigDL SpatialFullConvolution.

        Padding is 0 unless border_mode is "same"; in that case the padding per
        side is half of `(input - 1) * stride + kernel - output` for each of the
        two spatial dims. The result is passed through `combo_parameter_layer`
        with the layer config.

        :raises Exception: if dim_ordering is not "th", or if for "same" padding
                           the total padding of a dim is odd (it cannot be
                           split evenly between both sides).
        """
        dim_ordering = self.klayer.dim_ordering
        if dim_ordering != "th":
            raise Exception("Please use `th` for `dim_ordering`. `%s` is not supported for now."
                            % dim_ordering)
        output_shape = self.config["output_shape"]

        in_h = int(self.input_shape[2])
        in_w = int(self.input_shape[3])
        kernel_h = self.config["nb_row"]
        kernel_w = self.config["nb_col"]
        stride_h = self.config["subsample"][0]
        stride_w = self.config["subsample"][1]
        out_h = output_shape[2]
        out_w = output_shape[3]

        pad_w = pad_h = 0
        if self.config["border_mode"] == "same":
            total_pad_h = (in_h - 1) * stride_h + kernel_h - out_h  # 2 times pad_h
            total_pad_w = (in_w - 1) * stride_w + kernel_w - out_w  # 2 times pad_w
            # we only support pad_h as an int
            if total_pad_h % 2 != 0:
                raise Exception("For same padding, we only support padding on both sides for now. "
                                "Please make `(input_row - 1) * subsample[0] + nb_row - output_row` an even integer.")
            pad_h = int(total_pad_h / 2)
            # we only support pad_w as an int
            if total_pad_w % 2 != 0:
                raise Exception("For same padding, we only support padding on both sides for now. "
                                "Please make `(input_col - 1) * subsample[1] + nb_col - output_col` an even integer.")
            pad_w = int(total_pad_w / 2)

        full_conv = BLayer.SpatialFullConvolution(
            n_input_plane=int(self.input_shape[1]),
            n_output_plane=self.klayer.nb_filter,
            kw=self.klayer.nb_col,
            kh=self.klayer.nb_row,
            dw=self.klayer.subsample[1],
            dh=self.klayer.subsample[0],
            pad_w=pad_w,
            pad_h=pad_h,
            adj_w=0,
            adj_h=0,
            n_group=1,
            no_bias=not self.klayer.bias,
            wRegularizer=to_bigdl_reg(self.config["W_regularizer"]),
            bRegularizer=to_bigdl_reg(self.config["b_regularizer"]),
            bigdl_type="float")

        return self.combo_parameter_layer(full_conv, self.config)

    def create_maxpooling3d(self):
        """Convert the Keras 3D max-pooling layer into a BigDL VolumetricMaxPooling.

        pool_size[0]/[1]/[2] and strides[0]/[1]/[2] are mapped to the BigDL
        t/h/w arguments respectively; padding comes from `to_bigdl_3d_padding`.

        :raises Exception: if dim_ordering is not "th" or border_mode is 'same'.
        """
        if self.klayer.dim_ordering != "th":
            # Must read the attribute from self.klayer: a bare `klayer` is not
            # defined in this scope and would raise NameError instead.
            raise Exception("Please use `th` for `dim_ordering`. `%s` is not supported for now."
                            % self.klayer.dim_ordering)
        # TODO: border_mode = 'same'
        if self.klayer.border_mode == 'same':
            raise Exception("Unsupported border_mode: same")

        bpadT, bpadW, bpadH = to_bigdl_3d_padding(self.klayer.border_mode)
        blayer = BLayer.VolumetricMaxPooling(
                k_t=self.klayer.pool_size[0],
                k_w=self.klayer.pool_size[2],
                k_h=self.klayer.pool_size[1],
                d_t=self.klayer.strides[0],
                d_w=self.klayer.strides[2],
                d_h=self.klayer.strides[1],
                pad_t=bpadT,
                pad_w=bpadW,
                pad_h=bpadH,
                bigdl_type="float")
        return blayer

    def create_maxpooling2d(self):
        """Convert the Keras 2D max-pooling layer into a BigDL SpatialMaxPooling.

        pool_size[0]/strides[0] map to the height arguments and
        pool_size[1]/strides[1] to the width arguments; ceil mode is disabled.
        """
        data_format = self.get_bdim_order()
        pad_w, pad_h = to_bigdl_2d_padding(self.klayer.border_mode)
        pool_size = self.klayer.pool_size
        strides = self.klayer.strides
        return BLayer.SpatialMaxPooling(kw=pool_size[1],
                                        kh=pool_size[0],
                                        dw=strides[1],
                                        dh=strides[0],
                                        pad_w=pad_w,
                                        pad_h=pad_h,
                                        to_ceil=False,
                                        format=data_format,
                                        bigdl_type="float")

    def create_globalmaxpooling3d(self):
        """Convert the Keras global 3D max-pooling layer into a BigDL Sequential.

        A VolumetricMaxPooling whose kernel spans input_shape[2], [3] and [4]
        (as t, h and w) is followed by Squeeze(5), Squeeze(4) and Squeeze(3).

        :raises Exception: if the Keras layer's dim_ordering is not "th".
        """
        if self.klayer.dim_ordering != "th":
            raise Exception("Please use `th` for dim_ordering. `%s` is not supported for now."
                            % self.klayer.dim_ordering)
        kernel_t = int(self.input_shape[2])
        kernel_w = int(self.input_shape[4])
        kernel_h = int(self.input_shape[3])

        container = BLayer.Sequential()
        pooling = BLayer.VolumetricMaxPooling(k_t=kernel_t,
                                              k_w=kernel_w,
                                              k_h=kernel_h,
                                              d_t=1,
                                              d_w=1,
                                              d_h=1,
                                              pad_t=0,
                                              pad_w=0,
                                              pad_h=0,
                                              bigdl_type="float")
        container.add(pooling)
        # Drop the three pooled dims, last one first.
        for pooled_dim in (5, 4, 3):
            container.add(BLayer.Squeeze(pooled_dim))

        return container

    def create_globalaveragepooling3d(self):
        """Convert the Keras global 3D average-pooling layer into a BigDL Sequential.

        A VolumetricAveragePooling whose kernel spans input_shape[2], [3] and
        [4] (as t, h and w) is followed by Squeeze(5), Squeeze(4) and Squeeze(3).

        :raises Exception: if the Keras layer's dim_ordering is not "th".
        """
        if self.klayer.dim_ordering != "th":
            raise Exception("Please use `th` for dim_ordering. `%s` is not supported for now."
                            % self.klayer.dim_ordering)
        kernel_t = int(self.input_shape[2])
        kernel_w = int(self.input_shape[4])
        kernel_h = int(self.input_shape[3])

        container = BLayer.Sequential()
        pooling = BLayer.VolumetricAveragePooling(k_t=kernel_t,
                                                  k_w=kernel_w,
                                                  k_h=kernel_h,
                                                  d_t=1,
                                                  d_w=1,
                                                  d_h=1,
                                                  pad_t=0,
                                                  pad_w=0,
                                                  pad_h=0,
                                                  count_include_pad=False,
                                                  bigdl_type="float")
        container.add(pooling)
        # Drop the three pooled dims, last one first.
        for pooled_dim in (5, 4, 3):
            container.add(BLayer.Squeeze(pooled_dim))

        return container

    def create_averagepooling2d(self):
        """Convert the Keras 2D average-pooling layer into a BigDL SpatialAveragePooling.

        pool_size[0]/strides[0] map to the height arguments and
        pool_size[1]/strides[1] to the width arguments. Global pooling and ceil
        mode are disabled and padding is not counted in the average.
        """
        data_format = self.get_bdim_order()
        pad_w, pad_h = to_bigdl_2d_padding(self.klayer.border_mode)
        pool_size = self.klayer.pool_size
        strides = self.klayer.strides
        return BLayer.SpatialAveragePooling(kw=pool_size[1],
                                            kh=pool_size[0],
                                            dw=strides[1],
                                            dh=strides[0],
                                            pad_w=pad_w,
                                            pad_h=pad_h,
                                            global_pooling=False,
                                            ceil_mode=False,
                                            count_include_pad=False,
                                            divide=True,
                                            format=data_format,
                                            bigdl_type="float")

    def create_averagepooling3d(self):
        """Convert the Keras 3D average-pooling layer into a BigDL VolumetricAveragePooling.

        pool_size[0]/[1]/[2] and strides[0]/[1]/[2] are mapped to the BigDL
        t/h/w arguments respectively; padding comes from `to_bigdl_3d_padding`
        and is not counted in the average.

        :raises Exception: if dim_ordering is not "th" or border_mode is 'same'.
        """
        if self.klayer.dim_ordering != "th":
            # Must read the attribute from self.klayer: a bare `klayer` is not
            # defined in this scope and would raise NameError instead.
            raise Exception("Please use `th` for `dim_ordering`. `%s` is not supported for now."
                            % self.klayer.dim_ordering)
        # TODO: border_mode = 'same'
        if self.klayer.border_mode == 'same':
            raise Exception("Unsupported border_mode: same")

        bpadT, bpadW, bpadH = to_bigdl_3d_padding(self.klayer.border_mode)
        blayer = BLayer.VolumetricAveragePooling(
                k_t=self.klayer.pool_size[0],
                k_w=self.klayer.pool_size[2],
                k_h=self.klayer.pool_size[1],
                d_t=self.klayer.strides[0],
                d_w=self.klayer.strides[2],
                d_h=self.klayer.strides[1],
                pad_t=bpadT,
                pad_w=bpadW,
                pad_h=bpadH,
                count_include_pad=False,
                bigdl_type="float")
        return blayer

    def create_globalmaxpooling2d(self):
        """Convert the Keras global 2D max-pooling layer into a BigDL Sequential.

        A SpatialMaxPooling whose kernel and stride both equal the input's
        height and width is followed by two Squeeze layers that remove the
        pooled dims. For "NCHW" the height/width are input_shape[2]/[3];
        otherwise they are input_shape[1]/[2].
        """
        data_format = self.get_bdim_order()
        channel_first = data_format == "NCHW"
        width_axis, height_axis = (3, 2) if channel_first else (2, 1)
        kernel_w = int(self.input_shape[width_axis])
        kernel_h = int(self.input_shape[height_axis])

        container = BLayer.Sequential()
        pooling = BLayer.SpatialMaxPooling(kw=kernel_w,
                                           kh=kernel_h,
                                           dw=kernel_w,
                                           dh=kernel_h,
                                           pad_w=0,
                                           pad_h=0,
                                           to_ceil=False,
                                           format=data_format,
                                           bigdl_type="float")
        container.add(pooling)
        # Each entry is (dim to squeeze, num_input_dims), applied in order.
        squeeze_plan = ((3, 3), (2, 2)) if channel_first else ((2, 3), (1, 2))
        for squeeze_dim, num_input_dims in squeeze_plan:
            container.add(BLayer.Squeeze(squeeze_dim, num_input_dims=num_input_dims))
        return container

    def create_globalmaxpooling1d(self):
        """Convert the Keras global 1D max-pooling layer into a BigDL Sequential.

        The pipeline is: Reshape to [steps, 1, dim], an NHWC SpatialMaxPooling
        whose kernel height is the number of steps (kernel width 1, stride 1),
        then Squeeze(3) and Squeeze(2).
        """
        steps = int(self.input_shape[1])
        input_dim = int(self.input_shape[2])

        container = BLayer.Sequential()
        container.add(BLayer.Reshape([steps, 1, input_dim], True))
        pooling = BLayer.SpatialMaxPooling(kw=1,
                                           kh=steps,
                                           dw=1,
                                           dh=1,
                                           pad_w=0,
                                           pad_h=0,
                                           to_ceil=False,
                                           format="NHWC",
                                           bigdl_type="float")
        container.add(pooling)
        for pooled_dim in (3, 2):
            container.add(BLayer.Squeeze(pooled_dim))
        return container

    def create_globalaveragepooling1d(self):
        """Convert the Keras global 1D average-pooling layer into a BigDL Sequential.

        The pipeline is: Reshape to [steps, 1, dim], an NHWC
        SpatialAveragePooling whose kernel height is the number of steps
        (kernel width 1, stride 1), then Squeeze(3) and Squeeze(2).
        """
        steps = int(self.input_shape[1])
        input_dim = int(self.input_shape[2])

        container = BLayer.Sequential()
        container.add(BLayer.Reshape([steps, 1, input_dim], True))
        pooling = BLayer.SpatialAveragePooling(kw=1,
                                               kh=steps,
                                               dw=1,
                                               dh=1,
                                               pad_w=0,
                                               pad_h=0,
                                               global_pooling=False,
                                               ceil_mode=False,
                                               count_include_pad=False,
                                               divide=True,
                                               format="NHWC",
                                               bigdl_type="float")
        container.add(pooling)
        for pooled_dim in (3, 2):
            container.add(BLayer.Squeeze(pooled_dim))
        return container

    def create_maxpooling1d(self):
        """Convert the Keras 1D max-pooling layer into a BigDL Sequential.

        The pipeline is: Reshape to [steps, 1, dim], an NHWC SpatialMaxPooling
        using `pool_length` as kernel height and `stride` as vertical stride
        (both 1 along the width), then Squeeze(3).
        """
        pad_w, pad_h = to_bigdl_2d_padding(self.klayer.border_mode)
        steps = int(self.input_shape[1])
        input_dim = int(self.input_shape[2])

        container = BLayer.Sequential()
        container.add(BLayer.Reshape([steps, 1, input_dim], True))
        pooling = BLayer.SpatialMaxPooling(kw=1,
                                           kh=self.klayer.pool_length,
                                           dw=1,
                                           dh=self.klayer.stride,
                                           pad_w=pad_w,
                                           pad_h=pad_h,
                                           to_ceil=False,
                                           format="NHWC",
                                           bigdl_type="float")
        container.add(pooling)
        container.add(BLayer.Squeeze(3))
        return container

    def create_averagepooling1d(self):
        """Convert the Keras 1D average-pooling layer into a BigDL Sequential.

        The pipeline is: Reshape to [steps, 1, dim], an NHWC
        SpatialAveragePooling using `pool_length` as kernel height and `stride`
        as vertical stride (both 1 along the width), then Squeeze(3).
        """
        pad_w, pad_h = to_bigdl_2d_padding(self.klayer.border_mode)
        steps = int(self.input_shape[1])
        input_dim = int(self.input_shape[2])

        container = BLayer.Sequential()
        container.add(BLayer.Reshape([steps, 1, input_dim], True))
        pooling = BLayer.SpatialAveragePooling(kw=1,
                                               kh=self.klayer.pool_length,
                                               dw=1,
                                               dh=self.klayer.stride,
                                               pad_w=pad_w,
                                               pad_h=pad_h,
                                               global_pooling=False,
                                               ceil_mode=False,
                                               count_include_pad=False,
                                               divide=True,
                                               format="NHWC",
                                               bigdl_type="float")
        container.add(pooling)
        container.add(BLayer.Squeeze(3))
        return container

    def create_globalaveragepooling2d(self):
        """
        Build a BigDL Sequential that average-pools over both spatial axes.

        The kernel (and stride) size is read from input_shape: indices 3/2 for
        width/height when get_bdim_order() returns "NCHW", indices 2/1
        otherwise. Two Squeeze layers are appended after the pooling, with
        dimensions chosen according to the same ordering.

        :return: a BLayer.Sequential holding SpatialAveragePooling and two
                 Squeeze layers, in that order.
        """
        order = self.get_bdim_order()
        channels_first = order == "NCHW"

        # Pick the input_shape indices of width/height and the squeeze dims.
        if channels_first:
            w_axis, h_axis = 3, 2
            first_squeeze, second_squeeze = 3, 2
        else:
            w_axis, h_axis = 2, 1
            first_squeeze, second_squeeze = 2, 1
        pool_w = int(self.input_shape[w_axis])
        pool_h = int(self.input_shape[h_axis])

        seq = BLayer.Sequential()
        seq.add(BLayer.SpatialAveragePooling(kw=pool_w,
                                             kh=pool_h,
                                             dw=pool_w,
                                             dh=pool_h,
                                             pad_w=0,
                                             pad_h=0,
                                             global_pooling=False,
                                             ceil_mode=False,
                                             count_include_pad=False,
                                             divide=True,
                                             format=order,
                                             bigdl_type="float"))
        seq.add(BLayer.Squeeze(first_squeeze, num_input_dims=3))
        seq.add(BLayer.Squeeze(second_squeeze, num_input_dims=2))
        return seq

    def create_upsampling1d(self):
        """
        Create a BigDL UpSampling1D layer from the Keras layer's length.

        :return: BLayer.UpSampling1D built with self.klayer.length.
        """
        length = self.klayer.length
        return BLayer.UpSampling1D(length)

    def create_upsampling2d(self):
        """
        Create a BigDL UpSampling2D layer from the Keras layer's size.

        :return: BLayer.UpSampling2D using self.klayer.size and the data format
                 returned by get_bdim_order().
        """
        data_format = self.get_bdim_order()
        size = self.klayer.size
        return BLayer.UpSampling2D(size=size, data_format=data_format)

    def create_upsampling3d(self):
        """
        Create a BigDL UpSampling3D layer from the Keras layer's size.

        Only the "th" dim_ordering is accepted. When the json definition carries
        no dim_ordering, a warning is emitted *before* the ordering check so
        that a failure caused by the implicit default is explained to the user.

        :return: BLayer.UpSampling3D built with self.klayer.size.
        :raise Exception: if self.klayer.dim_ordering is not "th".
        """
        if "dim_ordering" not in self.config:
            # Warn first: if the default ordering is not "th" the raise below
            # would otherwise hide the fact that the value was never specified.
            warnings.warn("Cannot find dim_ordering from json definition. Using the default instead. "
                          "We only support th for now.")
        if self.klayer.dim_ordering != "th":
            raise Exception("Please use th for dim_ordering. %s is not supported for now." % self.klayer.dim_ordering)
        return BLayer.UpSampling3D(self.klayer.size)

    def create_gaussiannoise(self):
        """
        Create a BigDL GaussianNoise layer from the Keras layer's sigma.

        :return: BLayer.GaussianNoise built with float(self.klayer.sigma).
        """
        sigma = float(self.klayer.sigma)
        return BLayer.GaussianNoise(sigma)

    def create_gaussiandropout(self):
        """
        Create a BigDL GaussianDropout layer from the Keras layer's p.

        :return: BLayer.GaussianDropout built with float(self.klayer.p).
        """
        rate = float(self.klayer.p)
        return BLayer.GaussianDropout(rate)

    def create_highway(self):
        """
        Create a BigDL Highway layer from the Keras layer definition.

        The activation is None when the configured activation is 'linear';
        otherwise it is looked up by name and labelled "<name>_<activation>".

        :return: BLayer.Highway sized by input_shape[1], with bias flag and
                 regularizers taken from the Keras layer and its config.
        """
        cfg = self.config
        act_name = cfg["activation"]

        activation = None
        if act_name != 'linear':
            activation = get_activation_by_name(act_name,
                                                "%s_%s" % (cfg["name"], act_name))

        return BLayer.Highway(size=int(self.input_shape[1]),
                              with_bias=self.klayer.bias,
                              activation=activation,
                              wRegularizer=to_bigdl_reg(cfg["W_regularizer"]),
                              bRegularizer=to_bigdl_reg(cfg["b_regularizer"]))

    def create_maxoutdense(self):
        """
        Create a BigDL Maxout layer from the Keras layer definition.

        :return: BLayer.Maxout whose input size is input_shape[1] and whose
                 output size, feature count, bias flag and regularizers come
                 from the Keras layer and its config.
        """
        layer = self.klayer
        maxout_args = dict(
            input_size=int(self.input_shape[1]),
            output_size=layer.output_dim,
            maxout_number=layer.nb_feature,
            with_bias=layer.bias,
            w_regularizer=to_bigdl_reg(self.config["W_regularizer"]),
            b_regularizer=to_bigdl_reg(self.config["b_regularizer"])
        )
        return BLayer.Maxout(**maxout_args)

    def create_masking(self):
        """
        Create a BigDL Masking layer from the Keras layer's mask_value.

        :return: BLayer.Masking built with float(self.klayer.mask_value).
        """
        mask_value = float(self.klayer.mask_value)
        return BLayer.Masking(mask_value)

    def create_srelu(self):
        """
        Create a BigDL SReLU layer and apply the four configured initializers.

        A warning is issued when the config has no "shared_axes" entry. The
        layer is built from input_shape[1:]; the Keras layer's shared_axes are
        passed along unless they equal [None].

        :return: BLayer.SReLU with its init methods set from the t_left,
                 a_left, t_right and a_right initializers in the config.
        """
        cfg = self.config
        if "shared_axes" not in cfg:
            warnings.warn("Cannot find shared_axes from json definition. Using shared_axes=None instead.")

        shape = self.input_shape[1:]
        # Order matters: set_init_method receives these positionally.
        init_keys = ("t_left_init", "a_left_init", "t_right_init", "a_right_init")
        init_methods = [to_bigdl_init(cfg[key]) for key in init_keys]

        shared_axes = self.klayer.shared_axes
        if shared_axes != [None]:
            srelu = BLayer.SReLU(shape, shared_axes)
        else:
            srelu = BLayer.SReLU(shape)

        srelu.set_init_method(*init_methods)
        return srelu

    def create_separableconvolution2d(self):
        """
        Create a BigDL SpatialSeperableConvolution from the Keras layer.

        The input channel count is read from input_shape[1] for "NCHW" and from
        input_shape[3] for "NHWC". The resulting layer is passed through
        combo_parameter_layer together with the layer config.

        :return: the value returned by combo_parameter_layer for the new layer.
        :raise Exception: if the Keras backend is not 'tensorflow'.
        """
        backend_name = keras.backend.backend()
        if backend_name != 'tensorflow':
            raise Exception('Please use tensorflow backend for keras 1.2.2 '
                            'if you want to load SeparableConv2D')

        order = self.get_bdim_order()
        if order == "NCHW":
            in_channels = int(self.input_shape[1])
        elif order == "NHWC":
            in_channels = int(self.input_shape[3])

        layer = self.klayer
        cfg = self.config
        pad_w, pad_h = to_bigdl_2d_padding(layer.border_mode)

        conv_args = dict(
            n_input_channel=in_channels,
            n_output_channel=layer.nb_filter,
            depth_multiplier=layer.depth_multiplier,
            kernel_w=layer.nb_col,
            kernel_h=layer.nb_row,
            stride_w=layer.subsample[1],
            stride_h=layer.subsample[0],
            pad_w=pad_w,
            pad_h=pad_h,
            with_bias=layer.bias,
            data_format=order,
            w_regularizer=to_bigdl_reg(cfg["depthwise_regularizer"]),
            b_regularizer=to_bigdl_reg(cfg["b_regularizer"]),
            p_regularizer=to_bigdl_reg(cfg["pointwise_regularizer"])
        )
        separable_conv = BLayer.SpatialSeperableConvolution(**conv_args)

        return self.combo_parameter_layer(separable_conv, cfg)

    def create_activityregularization(self):
        """
        Create a BigDL ActivityRegularization layer from the Keras layer.

        :return: BLayer.ActivityRegularization using the Keras layer's l1 and
                 l2 values.
        """
        layer = self.klayer
        return BLayer.ActivityRegularization(l1=layer.l1, l2=layer.l2)

    def create_spatialdropout1d(self):
        """
        Create a BigDL SpatialDropout1D layer from the Keras layer's p.

        :return: BLayer.SpatialDropout1D with init_p=float(self.klayer.p).
        """
        drop_p = float(self.klayer.p)
        return BLayer.SpatialDropout1D(init_p=drop_p)

    def create_spatialdropout2d(self):
        """
        Create a BigDL SpatialDropout2D layer from the Keras layer's p.

        :return: BLayer.SpatialDropout2D with init_p=float(self.klayer.p) and
                 the data format returned by get_bdim_order().
        """
        data_format = self.get_bdim_order()
        drop_p = float(self.klayer.p)
        return BLayer.SpatialDropout2D(init_p=drop_p, data_format=data_format)

    def create_spatialdropout3d(self):
        """
        Create a BigDL SpatialDropout3D layer from the Keras layer's p.

        :return: BLayer.SpatialDropout3D with init_p=float(self.klayer.p) and
                 the data format returned by get_bdim_order().
        """
        data_format = self.get_bdim_order()
        drop_p = float(self.klayer.p)
        return BLayer.SpatialDropout3D(init_p=drop_p, data_format=data_format)

    def create_locallyconnected1d(self):
        """
        Build the BigDL equivalent of a 1D locally connected layer.

        The input is reshaped to [input_shape[1], 1, input_shape[2]], passed to
        an NHWC LocallyConnected2D of width 1, and then through Squeeze(3).
        When the configured activation is not "linear", the Sequential is fused
        with the activation layer named "<name>_<activation>".

        :return: the BLayer.Sequential, or the result of fuse() when an
                 activation is configured.
        """
        layer = self.klayer
        cfg = self.config
        height = int(self.input_shape[1])
        channels = int(self.input_shape[2])

        seq = BLayer.Sequential()
        seq.add(BLayer.Reshape([height, 1, channels], True))
        seq.add(BLayer.LocallyConnected2D(
            n_input_plane=channels,
            input_width=1,
            input_height=height,
            n_output_plane=layer.nb_filter,
            kernel_w=1,
            kernel_h=layer.filter_length,
            stride_w=1,
            stride_h=layer.subsample_length,
            pad_w=0,
            pad_h=0,
            wRegularizer=to_bigdl_reg(cfg["W_regularizer"]),
            bRegularizer=to_bigdl_reg(cfg["b_regularizer"]),
            with_bias=layer.bias,
            data_format="NHWC"))
        seq.add(BLayer.Squeeze(3))

        act_name = cfg["activation"]
        if act_name == "linear":
            return seq
        activation = get_activation_by_name(act_name,
                                            "%s_%s" % (cfg["name"], act_name))
        return self.fuse(seq, activation)

    def create_locallyconnected2d(self):
        """
        Create a BigDL LocallyConnected2D layer from the Keras layer.

        Channel count, height and width are read from input_shape according to
        the ordering returned by get_bdim_order() ("NCHW" or "NHWC"). When the
        configured activation is not "linear", the layer is fused with the
        activation layer named "<name>_<activation>".

        :return: the BLayer.LocallyConnected2D, or the result of fuse() when an
                 activation is configured.
        """
        order = self.get_bdim_order()
        shape = self.input_shape

        if order == "NCHW":
            channels, rows, cols = int(shape[1]), int(shape[2]), int(shape[3])
        elif order == "NHWC":
            rows, cols, channels = int(shape[1]), int(shape[2]), int(shape[3])

        layer = self.klayer
        cfg = self.config
        pad_w, pad_h = to_bigdl_2d_padding(layer.border_mode)

        local_conv = BLayer.LocallyConnected2D(
            n_input_plane=channels,
            input_width=cols,
            input_height=rows,
            n_output_plane=layer.nb_filter,
            kernel_w=layer.nb_col,
            kernel_h=layer.nb_row,
            stride_w=layer.subsample[1],
            stride_h=layer.subsample[0],
            pad_w=pad_w,
            pad_h=pad_h,
            wRegularizer=to_bigdl_reg(cfg["W_regularizer"]),
            bRegularizer=to_bigdl_reg(cfg["b_regularizer"]),
            with_bias=layer.bias,
            data_format=order)

        act_name = cfg["activation"]
        if act_name == "linear":
            return local_conv
        activation = get_activation_by_name(act_name,
                                            "%s_%s" % (cfg["name"], act_name))
        return self.fuse(local_conv, activation)

    def combo_parameter_layer(self, blayer, config):
        """
        Name a BigDL layer, set its init method and attach its activation.

        The layer is named config["name"]. If it exposes set_init_method, the
        weight initializer is taken from config["init"] and the second
        initializer is BInit.Zeros(); any failure there is downgraded to a
        warning. A non-"linear" activation is fused onto the layer.

        :param blayer: the BigDL layer to configure.
        :param config: the Keras layer config; "name", "activation" and
                       (for layers with set_init_method) "init" are read.
        :return: blayer itself when the activation is "linear", otherwise the
                 result of fuse(blayer, activation).
        """
        layer_name = config["name"]
        blayer.set_name(layer_name)

        if hasattr(blayer, "set_init_method"):
            try:
                # Keras always set this to be zeros
                blayer.set_init_method(to_bigdl_init(config["init"]), BInit.Zeros())
            except Exception:
                warnings.warn("We don't support initialization " + config["init"] + " for now. "
                              + "Using the default instead.")

        act_name = config["activation"]
        if act_name == "linear":
            # "linear" means doing nothing
            return blayer
        activation = get_activation_by_name(act_name,
                                            "%s_%s" % (layer_name, act_name))
        return self.fuse(blayer, activation)

    def get_value_from_init(self, kinit_method, shape):
        """
        Materialize a constant Keras initializer as a numpy array.

        :param kinit_method: name of the Keras init method; "zero" and "one"
                             are supported.
        :param shape: shape of the array to create.
        :return: np.zeros(shape) for "zero", np.ones(shape) for "one".
        :raise Exception: for any other init method; the message names it.
        """
        if kinit_method == "zero":
            return np.zeros(shape)
        if kinit_method == "one":
            return np.ones(shape)
        # Wrap in a 1-tuple so a tuple-valued argument still formats correctly.
        raise Exception("We don't support %s for now" % (kinit_method,))

    def fuse(self, src_blayer, activation):  # activation is a layer
        """
        Chain a layer and an activation layer into one named Sequential.

        :param src_blayer: the BigDL layer placed first.
        :param activation: the activation layer placed second.
        :return: a BLayer.Sequential containing both layers and carrying the
                 name reported by src_blayer.name().
        """
        fused = BLayer.Sequential()
        for part in (src_blayer, activation):
            fused.add(part)
        # Reuse the source layer's name for the fused container.
        fused.set_name(src_blayer.name())
        return fused